import numpy as np
import torch
from torch._C import device
import torch.optim as optim
from multiprocessing import Pool
import dataloader
import model.model_ELIS as model_ELIS
import paramENC as paramzzl

import tool

def Train(args, Model, data, target, optimizer, epoch):
    """Run one training epoch over shuffled mini-batches of sample indices.

    The model is driven by sample *indices*, not by the data tensor. Each
    batch is a slice of a random permutation of ``range(data.shape[0])``, and
    that slice is passed to ``Model(...)`` and ``Model.Loss(...)``.

    Args:
        args: Configuration mapping. Reads ``'batch_size'`` and
            ``'trainquiet'`` (0 enables progress printing).
        Model: Model called as ``Model(index)``. It must expose ``device``,
            ``Loss(output, index)`` (returning at least two scalar tensors)
            and ``vList``.
        data: Training data. Only ``data.shape[0]`` (the sample count) is
            used here.
        target: Unused. Kept so the signature parallels ``Test``.
        optimizer: Optimizer, zeroed and stepped once per batch.
        epoch: Current epoch number. Progress is printed when
            ``epoch % 100 == 0`` and ``args['trainquiet'] == 0``.

    Returns:
        A 7-element list. Entries 0 and 1 hold the epoch sums of the first
        two losses from ``Model.Loss``; the remaining entries stay 0.
    """
    batch_size = int(args['batch_size'])
    num_train_sample = data.shape[0]
    train_loss_sum = [0, 0, 0, 0, 0, 0, 0]

    Model.train()

    # Nothing to train on. Previously this case crashed with a NameError on
    # ``batch_idx`` when the epoch summary was printed.
    if num_train_sample == 0:
        return train_loss_sum

    verbose = args['trainquiet'] == 0 and epoch % 100 == 0
    # Integer ceiling division: the last batch may be shorter than batch_size.
    num_batch = -(-num_train_sample // batch_size)
    rand_index_i = torch.randperm(num_train_sample, device=Model.device)

    for batch_idx in range(num_batch):
        start = batch_idx * batch_size
        end = min(start + batch_size, num_train_sample)
        sample_index_i = rand_index_i[start:end]

        optimizer.zero_grad()
        output = Model(sample_index_i)
        loss_list = Model.Loss(output, sample_index_i)
        # A single backward pass on the sum accumulates the same gradients
        # as two separate backward() calls. Unlike two calls, it does not
        # fail when both losses share autograd nodes, whose saved tensors
        # the first pass would free.
        (loss_list[0] + loss_list[1]).backward()
        train_loss_sum[0] += loss_list[0].item()
        train_loss_sum[1] += loss_list[1].item()

        if verbose:
            print('batch {} loss {}'.format(batch_idx, loss_list[0].item()))

        optimizer.step()

    if verbose:
        # Reports the start offset of the last batch, as before.
        print('Train Epoch: {} [{}/{} ({:.0f}%)] \t Loss: {}'.format(
            epoch, batch_idx * batch_size, num_train_sample,
            batch_size * 100. * batch_idx / num_train_sample, train_loss_sum))
        print(Model.vList[-1])

    return train_loss_sum


def Test(args, Model, data, target, optimizer, epoch):
    """Embed and reconstruct every sample of ``data`` in order, without gradients.

    The samples are processed in consecutive batches of ``args['batch_size']``
    with the model in eval mode.

    Args:
        args: Configuration mapping. Reads ``'batch_size'``.
        Model: Model exposing ``test(batch)`` and ``Generate(embedding)``.
            Both must return sequences whose last element is the tensor
            collected here.
        data: Tensor whose first dimension indexes samples. It is converted
            with ``.float()`` before being fed to the model.
        target, optimizer, epoch: Unused. Kept so the signature parallels
            ``Train``.

    Returns:
        ``(outem, outre)``: NumPy arrays built by concatenating, along
        axis 0 and in sample order, ``Model.test(batch)[-1]`` and
        ``Model.Generate(...)[-1]`` for every batch.

    Raises:
        ValueError: If ``data`` contains no samples.
    """
    batch_size = int(args['batch_size'])
    num_sample = data.shape[0]
    if num_sample == 0:
        # Previously this case failed with an UnboundLocalError.
        raise ValueError('Test() received data with no samples')

    Model.eval()
    data_f = data.float()  # convert once instead of once per batch
    index = torch.arange(num_sample)
    em_chunks, rec_chunks = [], []

    # Outputs are detached right away, so building an autograd graph here
    # would only waste memory.
    with torch.no_grad():
        for start in range(0, num_sample, batch_size):
            # Advanced indexing (not slicing) gives the model a copy, as before.
            datab = data_f[index[start:start + batch_size]]
            em = Model.test(datab)
            rec = Model.Generate(em[-1])
            em_chunks.append(em[-1].detach().cpu().numpy())
            rec_chunks.append(rec[-1].detach().cpu().numpy())

    # Concatenate once at the end. Concatenating inside the loop is quadratic.
    return np.concatenate(em_chunks, axis=0), np.concatenate(rec_chunks, axis=0)


def main(args, pool=None,):
    """Train and periodically evaluate a LISV2_MLP model for one configuration.

    Steps:
      1. Create an output directory from ``args['data_name']`` and
         ``args['name']`` and save ``args`` into it.
      2. Load the train and test splits.
      3. Train for epochs ``0..args['epochs']`` inclusive.
      4. Every ``args['log_interval']`` epochs, plot the train and test
         embeddings. When ``args['Dec']`` is truthy, also plot the train
         reconstructions. Then save the train embedding with
         ``tool.SaveData``.

    Learning-rate schedule: once training epoch
    ``args.get('lr_decay_epoch', 500)`` has run, every parameter group's
    learning rate is divided by ``args.get('lr_decay_factor', 10)``. Both
    keys are optional, and their defaults reproduce the previously
    hard-coded schedule.

    Args:
        args: Configuration dict, e.g. from a ``paramzzl.GetParam*`` helper.
        pool: Unused. Kept for backward compatibility.

    Returns:
        The output directory path returned by ``tool.GetPath``.
    """

    # args = paramzzl.GetParamMnistL()
    path = tool.GetPath(args['data_name']+'_'+args['name'])
    tool.SaveParam(path, args)

    # Optional schedule overrides. The defaults match the old hard-coded
    # "divide lr by 10 after epoch 500". args itself is not modified, so the
    # saved parameters are unchanged.
    lr_decay_epoch = args.get('lr_decay_epoch', 500)
    lr_decay_factor = args.get('lr_decay_factor', 10)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)
    data_train, label_train = dataloader.GetData(
        args,
        device=device,
    )
    data_test, label_test = dataloader.GetData(
        args,
        device=device,
        testBool=True
    )
    # NOTE(review): the seed is set only after both GetData calls, so any
    # randomness inside data loading is not controlled by args['seed'] —
    # confirm this is intended.
    tool.SetSeed(args['seed'])
    gifPloterLatentTrain = tool.GIFPloter()
    gifPloterLatentTest = tool.GIFPloter()
    gifPloterrecons = tool.GIFPloter()
    Model = model_ELIS.LISV2_MLP(
        data_train,
        device=device,
        args=args,
        path=path
    ).to(device)

    optimizer = optim.Adam(Model.parameters(), lr=args['lr'])

    loss_his = []
    for epoch in range(0, args['epochs'] + 1):
        loss_item = Train(args, Model, data_train, label_train, optimizer,
                          epoch)

        # NOTE(review): Model.epoch is assigned only after Train() for this
        # epoch has run. If the model reads it during training, it sees the
        # previous iteration's value — confirm this is intended.
        Model.epoch = epoch
        loss_his.append(loss_item)

        if epoch == lr_decay_epoch:
            for param_group in optimizer.param_groups:
                param_group["lr"] = param_group["lr"] / lr_decay_factor
            print('change the learning rate')

        if epoch % args['log_interval'] == 0:

            em_train, re_train = Test(args, Model, data_train, label_train,
                                      optimizer, epoch)
            gifPloterLatentTrain.AddNewFig(
                em_train,
                label_train.detach().cpu(),
                his_loss=loss_his,
                path=path,
                graph=Model.GetInput(),
                link=None,
                title_='train_epoch_em{}_{}{}.png'.format(
                    epoch, args['perplexity'], args['v']))

            # The test reconstruction is computed but not used.
            em_test, _ = Test(args, Model, data_test, label_test,
                              optimizer, epoch)
            gifPloterLatentTest.AddNewFig(
                em_test,
                label_test.detach().cpu(),
                his_loss=loss_his,
                path=path,
                graph=Model.GetInput(),
                link=None,
                title_='test_epoch_em{}_{}{}.png'.format(
                    epoch, args['perplexity'], args['v']))

            if args['Dec']:
                # At epoch 0 the raw input is plotted as the reference;
                # afterwards the model's reconstruction is plotted.
                if epoch == 0:
                    input_data = data_train.detach().cpu().numpy()
                else:
                    input_data = re_train
                gifPloterrecons.AddNewFig(
                    input_data,
                    label_train.detach().cpu(),
                    his_loss=loss_his,
                    path=path,
                    graph=Model.GetInput(),
                    link=None,
                    title_='train_epoch_re{}_{}{}.png'.format(
                        epoch, args['perplexity'], args['v']),
                    dataset=args['name'])

            tool.SaveData(
                data_train,
                em_train,
                label_train,
                path=path,
                name='train_epoch{}'.format(str(epoch).zfill(6)))

    # NOTE(review): only the train-embedding plotter is saved as a GIF here;
    # gifPloterLatentTest and gifPloterrecons are filled above but never
    # saved by this function — confirm this is intended.
    gifPloterLatentTrain.SaveGIF()

    return path


if __name__ == "__main__":

    # Run one full experiment per dataset configuration, strictly in this
    # order. Each name is a paramzzl factory returning the args dict for
    # main(). The factory is looked up only when its turn comes, as in the
    # original hand-unrolled sequence.
    for param_factory_name in (
        'GetParamswishroll',
        'GetParamScurve',
        'GetParamseveredsphere',
        'GetParamSphere5500',
        'GetParamSphere10000',
        'GetParamcoil20',
        'GetParamcoil100rgb',
        'GetParamMnistL',
        'GetParamFMnistL',
    ):
        args = getattr(paramzzl, param_factory_name)()
        path = main(args)
    
